use super::{status::*, Attr, Deadline, Delay, RawTask, RawTaskVTable, TaskRef};
use crate::thread::Semaphore;
use crate::{prelude::*, Error};
use core::alloc::Layout;
use core::cell::Cell;
use core::future::Future;
use core::mem::{ManuallyDrop, MaybeUninit};
use core::pin::Pin;
use core::ptr::{self, NonNull};
use core::sync::atomic::{
    fence,
    Ordering::{self, Acquire, Relaxed, Release},
};
use core::task::{Context, Poll, Waker};
use core::time::Duration;
use hipool::{Allocator, Boxed, PoolAlloc};

/// Lifecycle of the future owned by a task.
enum Body<T>
where
    T: Future,
{
    /// The future has not produced a value yet and is still being polled.
    Running(T),
    /// The future completed; its output waits here until `exit_code` moves it out.
    Return(T::Output),
    /// Nothing left: the output was taken or the body was discarded.
    Clean,
}

enum Joinable {
    Sync(NonNull<Semaphore>),
    Async(Waker),
    Null,
}

/// Task state (future, join bookkeeping, allocator) used when the `task_mem_split`
/// feature keeps the `RawTask` header in a separate allocation.
///
/// `new_with` may allocate `attr.tail` extra bytes directly after this struct; `tail`
/// locates them.
#[cfg(feature = "task_mem_split")]
#[repr(C)]
pub(crate) struct Task<T, A>
where
    T: Future,
    A: Allocator,
{
    /// The future, or its output once it has completed.
    body: Cell<Body<T>>,
    /// How (and whether) a `JoinHandle` is waiting for completion.
    joinable: Cell<Joinable>,
    /// Allocator that owns this memory; moved out manually in `release`.
    alloc: ManuallyDrop<A>,
}

/// Task state with the `RawTask` header embedded as its first field, so the header,
/// the future and any `attr.tail` bytes share a single allocation.
///
/// `from_raw`/`from_raw_mut` recover the `Task` from a reference to `base`.
#[cfg(not(feature = "task_mem_split"))]
#[repr(C)]
pub(crate) struct Task<T, A>
where
    T: Future,
    A: Allocator,
{
    /// Type-erased header handed out (via `TaskRef`) to the rest of the runtime.
    base: RawTask,
    /// The future, or its output once it has completed.
    body: Cell<Body<T>>,
    /// How (and whether) a `JoinHandle` is waiting for completion.
    joinable: Cell<Joinable>,
    /// Allocator that owns this memory; moved out manually in `release`.
    alloc: ManuallyDrop<A>,
}

impl<T> Task<T, PoolAlloc>
where
    T: Future,
{
    /// Creates a task running `future`, allocated with `PoolAlloc`.
    ///
    /// Shorthand for `Task::new_in(PoolAlloc, future, attr)`.
    ///
    /// # Errors
    /// Propagates any error returned by `new_in`.
    #[allow(clippy::new_ret_no_self)] // returns a type-erased `TaskRef`, not `Self`
    pub fn new(future: T, attr: &Attr) -> Result<TaskRef, Error> {
        let pool = PoolAlloc;
        Self::new_in(pool, future, attr)
    }
}

impl<T: Future, A: Allocator + Clone> Task<T, A> {
    /// Function table passed to `RawTask::new`; each entry forwards to the associated
    /// function of the same name on this concrete `Task<T, A>`.
    const TASK_VTABLE: RawTaskVTable = RawTaskVTable {
        poll: Self::poll,
        poll_deadline: Self::poll_deadline,
        output: Self::output,
        join: Self::join,
        tail: Self::tail,
        release: Self::release,
    };

    /// Creates a task running `future`, with memory taken from `alloc`.
    ///
    /// How the future is wrapped depends on `attr`:
    /// * `attr.timeout == Duration::MAX`: used as-is;
    /// * otherwise, with `attr.delay` set: wrapped in `Delay::new(attr.timeout, ..)`;
    /// * otherwise: wrapped in `Deadline::new(attr.timeout, ..)`.
    ///
    /// # Errors
    /// Propagates any error from `new_with`, e.g. a failed allocation.
    pub fn new_in(alloc: A, future: T, attr: &Attr) -> Result<TaskRef, Error> {
        let unbounded = attr.timeout == Duration::MAX;
        match (unbounded, attr.delay) {
            (true, _) => Task::new_with(alloc, future, attr),
            (false, true) => Task::new_with(alloc, Delay::new(attr.timeout, future), attr),
            (false, false) => Task::new_with(alloc, Deadline::new(attr.timeout, future), attr),
        }
    }

    /// Allocates a task whose `RawTask` header and `Task` body live in two separate
    /// allocations from `alloc`. The header's `future` pointer refers to the body.
    ///
    /// When `attr.tail` is non-empty, the body allocation is enlarged so the tail bytes
    /// follow `Self`, at the offset that `tail` later computes.
    ///
    /// # Errors
    /// Returns the allocator's error if either allocation fails.
    #[cfg(feature = "task_mem_split")]
    fn new_with(alloc: A, future: T, attr: &Attr) -> Result<TaskRef, Error> {
        let header = Boxed::uninit_in::<RawTask>(alloc.clone())?;
        let slot = match attr.tail.size() {
            0 => Boxed::uninit_in::<Self>(alloc.clone())?,
            _ => {
                let (layout, _tail_offset) = Layout::new::<Self>().extend(attr.tail).unwrap();
                // SAFETY: `layout` begins with `Layout::new::<Self>()`, so (provided
                // `new_buf_in` honours `layout`) the buffer can hold a `Self`.
                unsafe {
                    Boxed::new_buf_in(alloc.clone(), layout)?
                        .cast_unchecked::<MaybeUninit<Self>>()
                }
            }
        };
        let body_ptr = slot.as_ptr().cast::<()>();
        let header = header.write(RawTask::new(&Self::TASK_VTABLE, attr, body_ptr));
        let state = Self {
            body: Cell::new(Body::Running(future)),
            joinable: Cell::new(Joinable::Null),
            alloc: ManuallyDrop::new(alloc),
        };
        let _ = slot.write(state).leak();
        Ok(TaskRef::new(NonNull::from(header.leak())))
    }

    /// Allocates the whole task (embedded `RawTask` header included) as one block from
    /// `alloc`. If `attr.tail` is non-empty, that many extra bytes follow `Self`.
    ///
    /// # Errors
    /// Returns the allocator's error if the allocation fails.
    #[cfg(not(feature = "task_mem_split"))]
    fn new_with(alloc: A, future: T, attr: &Attr) -> Result<TaskRef, Error> {
        if attr.tail.size() == 0 {
            // No tail requested: an exactly-sized allocation of `Self` is enough.
            let boxed = Boxed::new_in(
                alloc.clone(),
                Self {
                    base: RawTask::new(&Self::TASK_VTABLE, attr),
                    body: Cell::new(Body::Running(future)),
                    joinable: Cell::new(Joinable::Null),
                    alloc: ManuallyDrop::new(alloc),
                },
            )?;
            return Ok(TaskRef::new(NonNull::from(&boxed.leak().base)));
        }

        // Tail requested: over-allocate so the tail bytes sit right after `Self`, using
        // the same `Layout::extend` computation that `tail` performs.
        let (layout, _tail_offset) = Layout::new::<Self>().extend(attr.tail).unwrap();
        // SAFETY: `layout` begins with `Layout::new::<Self>()`, so (provided
        // `new_buf_in` honours `layout`) the buffer can hold a `Self`.
        let slot = unsafe {
            Boxed::new_buf_in(alloc.clone(), layout)?.cast_unchecked::<MaybeUninit<Self>>()
        };
        let boxed = slot.write(Self {
            base: RawTask::new(&Self::TASK_VTABLE, attr),
            body: Cell::new(Body::Running(future)),
            joinable: Cell::new(Joinable::Null),
            alloc: ManuallyDrop::new(alloc),
        });
        Ok(TaskRef::new(NonNull::from(&boxed.leak().base)))
    }
    /// Returns the `Task` body that a split-out `RawTask` header points to.
    ///
    /// # Safety
    /// `task.future` must point to a live, initialized `Self`, as set up by `new_with`.
    #[cfg(feature = "task_mem_split")]
    unsafe fn from_raw(task: &RawTask) -> &Self {
        let body = task.future.cast::<Self>();
        &*body
    }

    /// Raw mutable-pointer counterpart of `from_raw`. It only computes the address;
    /// the caller decides how to dereference it.
    ///
    /// # Safety
    /// Same requirements as `from_raw`.
    #[cfg(feature = "task_mem_split")]
    unsafe fn from_raw_mut(task: &mut RawTask) -> *mut Self {
        let body: *const Self = task.future.cast();
        body.cast_mut()
    }

    /// Recovers the enclosing `Task` from a reference to its embedded `base` header.
    ///
    /// # Safety
    /// `task` must be the `base` field of a live `Task<T, A>` with exactly these type
    /// parameters.
    #[cfg(not(feature = "task_mem_split"))]
    unsafe fn from_raw(task: &RawTask) -> &Self {
        hioff::container_of!(task, Self, base)
    }

    /// Raw mutable-pointer counterpart of `from_raw`: computes the address of the
    /// `Task` that contains `task` as its `base` field.
    ///
    /// # Safety
    /// Same requirements as `from_raw`. Note that the returned pointer overlaps `*task`.
    #[cfg(not(feature = "task_mem_split"))]
    unsafe fn from_raw_mut(task: &mut RawTask) -> *mut Self {
        hioff::container_of_mut!(task, Self, base)
    }

    /// Records how the task ended and notifies a waiting `JoinHandle`, if any.
    ///
    /// `ret` is one of:
    /// * `Ok`: the future's output, stored in the body and reported as `STAT_FINISH`;
    /// * `Err`: a raw exit status such as `STAT_TIMEOUT`.
    ///
    /// The status is moved from `STAT_RUN` to the exit status with a `Release` CAS. If
    /// the CAS fails because a joiner already registered itself, the exit status is
    /// published with `set_status` and the joiner is notified:
    /// * `STAT_RETURN`: an async `output` call stored a `Waker`, which is woken;
    /// * `STAT_JOIN`: a blocking `join` call stored a semaphore, which is posted.
    ///
    /// Any other CAS failure leaves the status untouched.
    fn exit_with(&mut self, task: &RawTask, ret: Result<T::Output, u8>) {
        let exit_code = match ret {
            Ok(ret) => {
                // Assign in place: this drops the (pinned) future where it lives.
                // `Cell::set` would move it out first, breaking the `Pin` drop guarantee.
                *self.body.get_mut() = Body::Return(ret);
                STAT_FINISH
            }
            Err(ret) => ret,
        };
        match task.status.cmp_xchg_status(STAT_RUN, exit_code, Release) {
            Ok(_status) => {}
            Err(STAT_RETURN) => {
                // Synchronize with the Release CAS in `output` that published the waker.
                fence(Acquire);
                // `output` may be called repeatedly and can run concurrently with this
                // code. It only writes `joinable` once the status is no longer
                // STAT_RETURN, and that write drops the stored waker. So clone the waker
                // while the status is still STAT_RETURN, then publish the exit status,
                // and do not touch `self` again afterwards.
                let waker = match unsafe { &*self.joinable.as_ptr() } {
                    Joinable::Async(waker) => waker.clone(),
                    _ => unreachable!("expect Joinable::Async"),
                };
                task.status.set_status(exit_code, Release);
                waker.wake();
            }
            Err(STAT_JOIN) => {
                // Synchronize with the Release CAS in `join` that published the semaphore.
                fence(Acquire);
                // `join` runs at most once and stays blocked until we post, so nothing
                // else touches `joinable` here.
                let mut sem = match self.joinable.replace(Joinable::Null) {
                    Joinable::Sync(sem) => sem,
                    _ => unreachable!("expect Joinable::Sync"),
                };
                task.status.set_status(exit_code, Release);
                // SAFETY: the semaphore lives on `join`'s stack, and `join` cannot return
                // (and drop it) before this post wakes it.
                unsafe { sem.as_mut().post() };
            }
            Err(_status) => {}
        }
    }

    /// Vtable hook: polls the task's future once.
    ///
    /// When the future completes, its output is handed to `exit_with`, which records it
    /// and may wake a waiting `JoinHandle`, and `Poll::Ready(())` is returned.
    ///
    /// # Panics
    /// Panics if the body is not `Body::Running`.
    fn poll(task: &mut RawTask, ctx: &mut Context<'_>) -> Poll<()> {
        let this = unsafe { &mut *Self::from_raw_mut(task) };
        // NOTE(review): without `task_mem_split`, `*task` is the `base` field of `*this`,
        // so `this` and `task` are overlapping mutable borrows; worth checking under Miri.
        let Body::Running(future) = this.body.get_mut() else {
            unreachable!("error body, should be Running")
        };
        // SAFETY: the future stays at this address inside the task allocation. For the
        // pin to be sound it must also be dropped in place rather than moved out.
        let pinned = unsafe { Pin::new_unchecked(future) };
        if let Poll::Ready(ret) = Future::poll(pinned, ctx) {
            this.exit_with(task, Ok(ret));
            return Poll::Ready(());
        }
        Poll::Pending
    }

    fn poll_deadline(task: &mut RawTask) {
        let this = unsafe { &mut *Self::from_raw_mut(task) };
        this.exit_with(task, Err(STAT_TIMEOUT));
    }

    /// Destroys the task and returns both allocations, the `Task` body and the
    /// `RawTask` header, to the allocator stored in the body.
    ///
    /// # Safety
    /// `task` must be a header created by `new_with`, and neither the header nor the
    /// body may be used after this call: both allocations are released here.
    #[cfg(feature = "task_mem_split")]
    unsafe fn release(task: &mut RawTask) {
        let this = unsafe { &mut *Self::from_raw_mut(task) };
        // Move the allocator out first. It lives inside the memory being freed, and
        // `drop_in_place` does not drop a `ManuallyDrop<A>` field.
        let alloc = unsafe { ManuallyDrop::take(&mut this.alloc) };
        // Drop `body`/`joinable` of the body, then the header, in place.
        unsafe { ptr::drop_in_place(this) };
        unsafe { ptr::drop_in_place(task) };
        // From here on `this` and `task` are only used as addresses.
        // NOTE(review): when `attr.tail` was non-empty the body was allocated with an
        // extended layout, yet `release` receives a `NonNull<Self>`. Confirm that
        // `Allocator::release` does not derive the size from the pointee type.
        unsafe { alloc.release(NonNull::from(this)) };
        unsafe { alloc.release(NonNull::from(task)) };
    }

    /// Destroys the task and returns its single allocation to the allocator stored in it.
    ///
    /// # Safety
    /// `task` must be the `base` of a `Task` created by `new_with`, and the task must not
    /// be used after this call: its memory is released here.
    #[cfg(not(feature = "task_mem_split"))]
    unsafe fn release(task: &mut RawTask) {
        let this = unsafe { &mut *Self::from_raw_mut(task) };
        // Move the allocator out first. It lives inside the memory being freed, and
        // `drop_in_place` does not drop a `ManuallyDrop<A>` field.
        let alloc = unsafe { ManuallyDrop::take(&mut this.alloc) };
        // Drops `base`, `body` and `joinable` in place.
        unsafe { ptr::drop_in_place(this) };
        // `this` is only used as an address from here on.
        // NOTE(review): with a non-empty `attr.tail` the allocation used an extended
        // layout, yet `release` receives a `NonNull<Self>`. Confirm that
        // `Allocator::release` does not derive the size from the pointee type.
        unsafe { alloc.release(NonNull::from(this)) };
    }

    unsafe fn exit_code(&self, task: &RawTask, output: *mut ()) -> Result<(), Error> {
        match task.status.status(Relaxed) {
            STAT_FINISH => {
                let output = output.cast::<T::Output>();
                match self.body.replace(Body::Clean) {
                    Body::Return(val) => output.write(val),
                    _ => unreachable!("error body, should be Return"),
                }
                Ok(())
            }
            STAT_TIMEOUT => {
                self.body.replace(Body::Clean);
                Err(Error::new(ETIMEDOUT))
            }
            STAT_ABORT => {
                self.body.replace(Body::Clean);
                Err(Error::new(ECANCELED))
            }
            _ => {
                self.body.replace(Body::Clean);
                Err(Error::default())
            }
        }
    }

    /// Blocks the calling thread until the task stops running, then collects its outcome.
    ///
    /// `abort`/`join`/`output` are all triggered by the `JoinHandle`, so they never run
    /// concurrently with each other. `STAT_JOIN` marks a blocking wait for completion.
    ///
    /// A stack-local semaphore is published through `joinable`, and the status is moved
    /// `STAT_RUN -> STAT_JOIN` with a `Release` CAS; `exit_with` later posts that
    /// semaphore. If the CAS fails, the task has already left `STAT_RUN`, and an
    /// `Acquire` fence pairs with the store that published its result.
    /// NOTE(review): on the blocking path, visibility of the result is assumed to come
    /// from the semaphore's wait/post pair. Confirm `Semaphore` provides that ordering.
    ///
    /// # Safety
    /// `output` must be valid for writing a `T::Output` (see `exit_code`).
    unsafe fn join(task: &RawTask, output: *mut ()) -> Result<(), Error> {
        let this = unsafe { Self::from_raw(task) };
        let sem = Semaphore::new()?;
        this.joinable.set(Joinable::Sync(NonNull::from(&sem)));
        match task.status.cmp_xchg_status(STAT_RUN, STAT_JOIN, Release) {
            // Registered: `exit_with` posts `sem` once the task stops running.
            Ok(_) => {
                sem.wait();
            }
            // Already past STAT_RUN: synchronize with the store that published it.
            Err(_) => fence(Acquire),
        }
        this.joinable.set(Joinable::Null);
        this.exit_code(task, output)
    }

    /// Async counterpart of `join`, polled on behalf of the `JoinHandle`.
    ///
    /// This must handle races with `exit_with` (reached from `Self::poll` when the
    /// future completes). The atomic exchange on the status arbitrates:
    /// * if the exchange succeeds (`STAT_RUN -> STAT_RETURN`), its `Release` guarantees
    ///   that `exit_with` sees the stored waker;
    /// * if it fails (e.g. the status is already `STAT_FINISH`), the `Acquire` status
    ///   load below guarantees the output can be read.
    ///
    /// NOTE(review): on later polls the waker is not refreshed, so a different `waker`
    /// passed in a subsequent call is never stored. Confirm this is acceptable.
    ///
    /// # Safety
    /// `output` must be valid for writing a `T::Output` (see `exit_code`).
    unsafe fn output(task: &RawTask, waker: &Waker, output: *mut ()) -> Poll<Result<(), Error>> {
        // Only the JoinHandle calls this, so setting `joinable` here while the status
        // is still STAT_RUN cannot race with `poll`.
        let this = unsafe { Self::from_raw(task) };
        let registered = match unsafe { &*this.joinable.as_ptr() } {
            Joinable::Null => false,
            Joinable::Async(_) => true,
            Joinable::Sync(_) => unreachable!("Expect Joinable::Async or Joinable::Null"),
        };
        if !registered {
            // First poll: store the waker, then announce it with a Release exchange.
            this.joinable.set(Joinable::Async(waker.clone()));
            let _ = task.status.cmp_xchg_status(STAT_RUN, STAT_RETURN, Release);
        }

        match task.status.status(Acquire) {
            STAT_RETURN => Poll::Pending,
            _ => {
                this.joinable.set(Joinable::Null);
                Poll::Ready(this.exit_code(task, output))
            }
        }
    }

    /// Returns a pointer to the extra tail area laid out right after `Self`.
    ///
    /// The offset comes from the same `Layout::extend` computation used by `new_with`.
    ///
    /// # Safety
    /// `layout` must equal the `attr.tail` the task was created with; otherwise the
    /// returned pointer may lie outside the allocation.
    unsafe fn tail(task: &RawTask, layout: Layout) -> *const () {
        let (_, offset) = Layout::new::<Self>().extend(layout).unwrap();
        let base: *const Self = unsafe { Self::from_raw(task) };
        base.cast::<u8>().add(offset).cast()
    }
}
